/**
 * 2017年5月27日
 */
package cn.edu.bjtu.driver;

import java.util.Arrays;
import java.util.concurrent.TimeUnit;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import cn.edu.bjtu.configuration.TextCategorizationCNNConfig;
import cn.edu.bjtu.core.CGNetworkDesigner;
import cn.edu.bjtu.core.CNNDataSetIteratorProviderHandler;
import cn.edu.bjtu.core.CNNNetworkDesignHandler;
import cn.edu.bjtu.datasource.dsiter.LengthExceptionDetectCNNDataSetIterator;
import cn.edu.bjtu.datasource.lsp.FDLabeledSentenceExceptionSupportProvider;
import cn.edu.bjtu.model.TextCategorizationCNNModel;

/**
 * Command-line driver for the "FD" text-categorization CNN.
 *
 * <p>Training mode (default): configures a computation graph with three parallel convolution
 * branches (kernel heights 3, 4 and 5 words), merges them, applies global pooling and a
 * {@value #NUM_CLASSES}-way softmax output. It then hands the graph and the train/test iterators
 * to {@link TextCategorizationCNNModel}, asks it to build the network, and saves the result.
 *
 * <p>{@code -restore} mode: loads a previously saved network and prints the predicted labels for
 * a few hard-coded sample documents.
 *
 * <p>Example invocations:
 * <pre>
 * cn.edu.bjtu.driver.FDCNNNet -w2v D:\textdata\mymodel\fdModel_100.txt -train D:\textdata\word2vecmodel\fddata_train -test D:\textdata\word2vecmodel\fddata_test -target D:\textdata\mymodel\fdcnn_100_all_8_1.bin -epoch 100
 * cn.edu.bjtu.driver.FDCNNNet -restore -w2v D:\textdata\mymodel\fdModel_100.txt -target D:\textdata\mymodel\fdcnn_100_all_8_1.bin
 * </pre>
 *
 * @author Alex
 */
public class FDCNNNet extends DriverSupport implements CGNetworkDesigner,CNNNetworkDesignHandler, CNNDataSetIteratorProviderHandler{
	/** Training data location used when {@code -train} is not given (the historical default). */
	private static final String DEFAULT_TRAIN_PATH = "D:\\textdata\\word2vecmodel\\fddata_train";
	/** Test data location used when {@code -test} is not given (the historical default). */
	private static final String DEFAULT_TEST_PATH = "D:\\textdata\\word2vecmodel\\fddata_test";
	/**
	 * Number of output classes of the softmax layer. Presumably these are the 20 corpus categories
	 * C11-Space, C15-Energy, C16-Electronics, C17-Communication, C19-Computer, C23-Mine,
	 * C29-Transport, C3-Art, C31-Enviornment, C32-Agriculture, C34-Economy, C35-Law, C36-Medical,
	 * C37-Military, C38-Politics, C39-Sports, C4-Literature, C5-Education, C6-Philosophy, C7-History
	 * (TODO confirm the order against the label order produced by the data iterator).
	 */
	private static final int NUM_CLASSES = 20;
	/**
	 * How long the main thread sleeps after starting the (asynchronous) network build before calling
	 * {@code saveModel()}.
	 */
	private static final long BUILD_WAIT_SECONDS = 100;

	TextCategorizationCNNModel model = null;

	/** Training data location; overridden by the {@code -train} option. */
	private String trainPath = DEFAULT_TRAIN_PATH;
	/** Test data location; overridden by the {@code -test} option. */
	private String testPath = DEFAULT_TEST_PATH;

	/**
	 * Declares the command-line options understood by this driver.
	 *
	 * @return the option set: {@code -w2v}, {@code -train}, {@code -test}, {@code -target},
	 *         {@code -epoch} (all taking a value) and the {@code -restore} flag
	 */
	protected Options getOptions(){
		Options options = new Options();  
		options.addOption("w2v", "word2vecfile", true, "word2vec model file");
		options.addOption("train", "trainfile", true, "train set to build the classifier");
		options.addOption("test", "testfile", true, "test set to help minize the network error");
		options.addOption("target","targetfile",true,"file to save the cnn network");
		options.addOption("epoch",true,"the epoch");
		options.addOption("restore",false,"whether restore or train a model");
		return options;
	}

	/**
	 * Applies the parsed command line to the configuration, then either restores a saved network
	 * and runs the prediction demo ({@code -restore}) or builds and saves a new network.
	 *
	 * @param cli the parsed command line
	 * @throws IllegalArgumentException if {@code -epoch} is not an integer
	 * @throws Exception propagated from model restore/build/save or from the sleep
	 */
	@Override
	public void runInternal(CommandLine cli) throws Exception {
		// Options shared by both modes.
		if (cli.hasOption("w2v")) {
			config.setWORD2VEC_MODEL_PATH(cli.getOptionValue("w2v"));
		}
		if (cli.hasOption("target")) {
			config.setCNN_NETWORK_SAVE_FILE(cli.getOptionValue("target"));
		}

		if (cli.hasOption("restore")) {
			runRestoreDemo();
			return;
		}

		// Training-only options. The paths are also remembered locally so that handleTrain /
		// handleTest actually read from the locations given on the command line.
		if (cli.hasOption("train")) {
			trainPath = cli.getOptionValue("train");
			config.setTEXT_DIR_OR_FILE(trainPath);
		}
		if (cli.hasOption("test")) {
			testPath = cli.getOptionValue("test");
			config.setTEST_DIR_OR_FILE(testPath);
		}
		if (cli.hasOption("epoch")) {
			config.setCNN_EPOCH(parseEpoch(cli.getOptionValue("epoch")));
		}
		// If the dataset first needs to be split into train/test sets (e.g. with FileDataSetSplit),
		// that step belongs here.

		model = TextCategorizationCNNModel.get();
		model.configNetworkDesignHandler(this);
		model.configDataSetIteratorHandlerProvider(this);
		// buildNetworkModel() runs asynchronously. The main thread sleeps so the build gets to run
		// first and acquire its lock; saveModel() is then expected to wait for the build to finish.
		// NOTE(review): a fixed sleep is a fragile way to order these calls; prefer an explicit
		// completion signal from TextCategorizationCNNModel if one is available.
		model.buildNetworkModel();
		TimeUnit.SECONDS.sleep(BUILD_WAIT_SECONDS);
		model.saveModel();
	}

	/** Restores the saved network and prints the predicted labels for a few sample documents. */
	private void runRestoreDemo() throws Exception {
		model = TextCategorizationCNNModel.get();
		model.restoreFromFile();
		String[] docs = new String[]{
				"曹操是中国历史上一个有名的人物",
				"孔令辉于今年4月继任中国乒乓球队女队总教练一职，目前正带队在德国准备世乒赛的比赛。孔令辉在接受采访时表示：“中国队这次是有备而来的，队员有责任做好打翻身仗的准备。因为在亚洲锦标赛上我们准备得不太充分，特别是中国队在亚锦赛上丢掉女单冠军后，每一位队员都有责任做好打翻身仗的准备。我相信，中国参赛的5位女运动员如果发挥正常的话，都有机会冲击单打金牌。",
				"USCG的各型船没有性能平庸之辈，都很有特色鲜明，针对某一特定问题，性能先进。如小小的救援艇，几十年来就强调任务角度倾覆的“自扶正”功能号称永不沉没，截击艇速度能达到50节以上，日常巡逻艇都有舒适的驾驶舱，三级防弹，减震座椅，完善的航电设备，还有风扇滑行艇专用于冰区和滩涂作业。",
				"据澎湃新闻不完全统计，包括碧桂园、华夏幸福、绿城、绿地、融创、华侨城、雅居乐、阳光城等企业在特色小镇方面均有新的动作。其中，碧桂园（02007.HK）已经成功布局了5个科技小镇；绿地控股（600606.SH）则把特色小镇模式纳入2017年发展战略，将重点选择有大城市购买力溢出效应、有产业导入支撑的一二线重点城市远郊及周边，投资启动特色小镇大盘项目，计划重点围绕智慧健康城、文化旅游城两个题材，形成开发模型和产品系列。",
				"充斥学校的“感恩教育”是种语言、情感暴力，但如果真要说它是否会对孩子造成严重的不良影响，倒也不会。“哭过就忘”是大多数人的反应。有听过“感恩演讲”的同学说：“那药劲儿过得特别快。那天之后同学都没有再提过那天的表现和感恩的话题。如果再有这种教育机会可以自主选择，我肯定不会再去听了。因为，那个感动好像是真的，又好像是假的。",
				"新华社曼谷5月29日电（记者汪瑾）泰国副总理兼国防部长巴维29日表示，警方仍在对本月22日发生在曼谷一家军用医院的爆炸案进行调查，目前已经拘留了近50名涉嫌与此案有关的人员。巴维表示，这些被拘留的人员中包括部分医院工作人员，目前警方正在对他们进行问询和调查。巴维拒绝评论这起案件中是否存在政治动机。",
				"根据多路媒体报道，除了十年版手机之外，苹果还将推出其他两款手机，型号分别是iPhone 7s和iPhone 7s Plus。这两款手机将是去年手机的自然升级版本，将会继续使用古老的液晶屏幕，但是在内存、应用处理器、闪存等方面引入一些改动。另据悉，苹果内部正在设计两个版本的十年版手机，不过只有一个版本才会作为正式产品上市销售，显然，苹果采取了双方案备份的战略，如果第一个方案失败，将会用第二个方案量产手机。",
		};
		for (String doc : docs) {
			System.out.println(Arrays.toString(model.predictDocumentLabelString(doc)));
		}
	}

	/**
	 * Parses the {@code -epoch} option value.
	 *
	 * @param value raw option value
	 * @return the parsed epoch count
	 * @throws IllegalArgumentException if the value is missing or not an integer
	 */
	private static int parseEpoch(String value) {
		if (value == null) {
			throw new IllegalArgumentException("-epoch requires an integer value");
		}
		try {
			return Integer.parseInt(value.trim());
		} catch (NumberFormatException e) {
			throw new IllegalArgumentException("-epoch must be an integer, got: " + value, e);
		}
	}

	/**
	 * Supplies the training iterator, reading from the {@code -train} location (or the default
	 * training path when the option was not given).
	 */
	@Override
	public DataSetIterator handleTrain(TextCategorizationCNNConfig config, WordVectors wv, TokenizerFactory tf,
			int batch, int senLen) throws Exception {
		return buildIterator(trainPath, wv, tf, batch, senLen);
	}

	/**
	 * Supplies the test iterator, reading from the {@code -test} location (or the default test
	 * path when the option was not given).
	 */
	@Override
	public DataSetIterator handleTest(TextCategorizationCNNConfig config, WordVectors wv, TokenizerFactory tf,
			int batch, int senLen) throws Exception {
		return buildIterator(testPath, wv, tf, batch, senLen);
	}

	/**
	 * Builds a CNN sentence iterator over the labeled documents found at {@code path}, using
	 * un-normalized word vectors.
	 */
	private static DataSetIterator buildIterator(String path, WordVectors wv, TokenizerFactory tf,
			int batch, int senLen) throws Exception {
		FDLabeledSentenceExceptionSupportProvider provider = new FDLabeledSentenceExceptionSupportProvider(path);
		return new LengthExceptionDetectCNNDataSetIterator.Builder()
				.tokenizerFactory(tf)
				.sentenceProvider(provider)
				.wordVectors(wv)
				.minibatchSize(batch)
				.maxSentenceLength(senLen)
				.useNormalizedWordVectors(false)
				.build();
	}

	/**
	 * Designs the computation graph: three convolution branches whose kernels span 3, 4 and 5
	 * words over the full embedding width, depth-merged, globally pooled, and fed to a softmax
	 * output with {@value #NUM_CLASSES} classes.
	 *
	 * @param embeddingWordVectorLength word-vector dimensionality (kernel and stride width)
	 * @return the graph configuration
	 */
	@Override
	public ComputationGraphConfiguration handleCGC(int embeddingWordVectorLength) {
		int vectorSize = embeddingWordVectorLength;
		int cnnLayerFeatureMaps = config.getCNNLayerFeatureMaps();
		ComputationGraphConfiguration graphConf = new NeuralNetConfiguration.Builder()
				.weightInit(WeightInit.RELU)
				.activation(Activation.LEAKYRELU)
				.updater(Updater.ADAM)
				// Same mode keeps every branch's output height equal so the merge vertex can
				// stack the three branches along the depth axis.
				.convolutionMode(ConvolutionMode.Same)
				.regularization(true).l2(0.0001)
				.learningRate(0.01)
				.graphBuilder()
				.addInputs("input")
				.addLayer("cnn3", new ConvolutionLayer.Builder()
						.kernelSize(3,vectorSize)
						.stride(1,vectorSize)
						.nIn(1)
						.nOut(cnnLayerFeatureMaps)
						.build(), "input")
				.addLayer("cnn4", new ConvolutionLayer.Builder()
						.kernelSize(4,vectorSize)
						.stride(1,vectorSize)
						.nIn(1)
						.nOut(cnnLayerFeatureMaps)
						.build(), "input")
				.addLayer("cnn5", new ConvolutionLayer.Builder()
						.kernelSize(5,vectorSize)
						.stride(1,vectorSize)
						.nIn(1)
						.nOut(cnnLayerFeatureMaps)
						.build(), "input")
				.addVertex("merge", new MergeVertex(), "cnn3", "cnn4", "cnn5")      // depth concatenation
				.addLayer("globalPool", new GlobalPoolingLayer.Builder()
						.poolingType(config.getPoolingType())
						.build(), "merge")
				.addLayer("out", new OutputLayer.Builder()
						.lossFunction(LossFunctions.LossFunction.MCXENT)
						.activation(Activation.SOFTMAX)
						// Three merged branches, each contributing cnnLayerFeatureMaps channels.
						.nIn(3*cnnLayerFeatureMaps)
						.nOut(NUM_CLASSES)
						.build(), "globalPool")
				.setOutputs("out")
				.build();
		return graphConf;
	}

	/**
	 * Not used: this driver only designs a {@link ComputationGraphConfiguration} (see
	 * {@link #handleCGC(int)}).
	 *
	 * @return always {@code null}
	 */
	@Override
	public MultiLayerConfiguration handleMLN(int embeddingWordVectorLength) {
		return null;
	}

}
